import numpy as np
import pandas as pd
import statsmodels.api as sm
import math
import sys

## Linear Estimator
def regEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
        """Fit ``outcome ~ pbvarb`` by OLS and return the slope plus implied rates.

        The model is fitted with HC1 heteroskedasticity-robust standard errors,
        and the number of observations used is printed.

        Parameters
        ----------
        dataset : DataFrame holding the columns named by ``pbvarb`` and ``outcome``.
        pbvarb : name of the regressor column.
        outcome : name of the dependent-variable column.

        Returns
        -------
        tuple
            ``(coef, se, black_aud_rate, non_black_aud_rate)`` where ``coef`` and
            ``se`` are the estimate and standard error for ``pbvarb``,
            ``non_black_aud_rate`` is the first fitted parameter and
            ``black_aud_rate`` is the sum of the first two fitted parameters.
        """
        model = sm.OLS.from_formula(outcome+' ~ '+pbvarb, data = dataset).fit(cov_type = 'HC1')
        coef = model.params[pbvarb]
        se = model.bse[pbvarb]
        # params is a label-indexed Series; integer keys in [] are deprecated
        # as positional lookups, so positional access goes through .iloc.
        first_param = model.params.iloc[0]
        second_param = model.params.iloc[1]
        black_aud_rate = first_param + second_param
        non_black_aud_rate = first_param
        print("number of obs :" + str(model.nobs))
        return coef, se, black_aud_rate, non_black_aud_rate


def get_cov_term(dataset):
        """Return the scaled covariance of bucket-mean residuals, unweighted and weighted.

        Rows are bucketed by ``bhat`` (and ``bhat_weight``) rounded to two
        decimals. Within each bucket the means of the epsilon and eta residual
        columns are taken and broadcast back to the rows. The covariance term is
        the mean of (product of the bucket means minus the product of their
        overall means), divided by ``p * (1 - p)`` with ``p`` the mean of the
        corresponding ``bhat`` column.

        ``dataset`` must contain ``bhat``, ``epsilonhat``, ``etahat`` and their
        ``*_weight`` counterparts. Helper columns (``*_bucket``, ``mu_*``,
        ``eta_epsilon_product*``) are added to ``dataset`` in place.

        Returns
        -------
        tuple
            ``(overall_cov, overall_cov_weight)``.
        """
        results = []
        # '' -> unweighted columns, '_weight' -> weighted columns
        for suffix in ('', '_weight'):
                prob_col = 'bhat' + suffix
                bucket_col = prob_col + '_bucket'
                eps_col = 'mu_epsilon' + suffix
                eta_col = 'mu_eta' + suffix
                prod_col = 'eta_epsilon_product' + suffix

                dataset[bucket_col] = dataset[prob_col].round(2)
                grouped = dataset.groupby(bucket_col)
                eps_lookup = grouped['epsilonhat' + suffix].mean().to_dict()
                eta_lookup = grouped['etahat' + suffix].mean().to_dict()

                # Broadcast each bucket's mean residual back onto its rows.
                dataset[eps_col] = dataset[bucket_col].map(eps_lookup).astype(float)
                dataset[eta_col] = dataset[bucket_col].map(eta_lookup).astype(float)
                dataset[prod_col] = dataset[eps_col] * dataset[eta_col]

                product_of_means = dataset[eta_col].mean() * dataset[eps_col].mean()
                numerator = (dataset[prod_col] - product_of_means).mean()
                p_bar = dataset[prob_col].mean()
                results.append(numerator / (p_bar * (1 - p_bar)))

        return results[0], results[1]


## Probabilistic Estimator 
def chenEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited', overall_cov = 0.0):
        """Return the probability-weighted difference in mean outcome.

        The first rate weights ``outcome`` by ``pbvarb``; the second weights it
        by ``1 - pbvarb``. The estimate is their difference plus
        ``overall_cov`` (0.0 by default).

        Parameters
        ----------
        dataset : DataFrame holding the columns named by ``pbvarb`` and ``outcome``.
        pbvarb : name of the probability column used as the weight.
        outcome : name of the outcome column.
        overall_cov : additive term included in the estimate.
        """
        prob = dataset[pbvarb]
        complement = 1 - prob
        result = dataset[outcome]
        weighted_rate = (prob * result).sum() / prob.sum()
        complement_rate = (complement * result).sum() / complement.sum()
        return weighted_rate - complement_rate + overall_cov

## Get standard errors
def getSEs(dataset,pbvarb='predicted_prob_black',outcome='audited', seReg = 0.0):
        """Return ``(seReg, seReg * multiplier)`` for the two estimators.

        ``seReg`` is supplied by the caller rather than re-estimated here; the
        second element rescales it by ``getSEMultiplier(dataset, pbvarb)``.
        ``outcome`` is accepted but not used in the body.
        """
        scaled_se = seReg * getSEMultiplier(dataset, pbvarb)
        return seReg, scaled_se

def getSEMultiplier(dataset,pbvarb='predicted_prob_black'):
        """Return sqrt(var(p) / (mean(p) * (1 - mean(p)))) for column ``pbvarb``.

        The variance is the pandas ``Series.var()`` of the column.
        """
        probs = dataset[pbvarb]
        p_bar = probs.mean()
        variance_ratio = probs.var() / (p_bar * (1 - p_bar))
        return np.sqrt(variance_ratio)


### Operational audits and predicted race probabilities
dataBISG = pd.read_csv("/REDACTED/data/final/individualBISG2014_full_final_new_eic.csv")
# Keep only rows that have a predicted probability.
dataBISG= dataBISG[~dataBISG.predicted_prob_black.isna()]

# NC rows of the BISG file, reduced to the id and the probability (renamed to
# pprob_black). Built with .loc + rename so that no column is assigned onto a
# filtered selection of dataBISG (avoids pandas' chained-assignment warning).
nc_BISG = (
    dataBISG
    .loc[dataBISG['state'] == "NC", ['taxpayer_id', 'predicted_prob_black']]
    .rename(columns={'predicted_prob_black': 'pprob_black'})
)

##NC data
ncdata = pd.read_csv("/REDACTED/disparity_data/NC_Analysis_Dataset_2014_tplevel.csv")

nc_weight = pd.read_stata('/REDACTED/BIFSG/nc_rewt_2014_coded_output_v4_December_ncsamp.dta')


## gibbs prob
dataGibbs = pd.read_csv("/REDACTED/data/clean/rhBIFSGIncMars.csv")
dataGibbs = dataGibbs[["taxpayer_id", "black"]]
# One row per taxpayer: keep the first occurrence of each taxpayer_id.
dataGibbs= dataGibbs.drop_duplicates(subset=['taxpayer_id'], keep='first')

# geography-only imputations
geo_only = pd.read_csv('/REDACTED/disparity_data/geo_only_fully_merged_test.csv')
geo_only = geo_only[geo_only.race_by_cbg_black.notnull()]




##########################################################
#### Table A.10
##########################################################

# From this point on, print() output goes to the results file.
sys.stdout = open("/REDACTED/disp_diff_samples_test2.txt", "w")

# NC records with a non-missing black_ind, joined to the reweighting columns
# and to the NC BISG probabilities (both inner merges on taxpayer_id).
ncdata = ncdata[ncdata.black_ind.notna()]
nc_weight = nc_weight[['taxpayer_id', 'black_prob', 'uswgt']]
ncdata = ncdata.merge(nc_weight, on='taxpayer_id')
ncdata = ncdata.merge(nc_BISG, on='taxpayer_id')

# Replace the file's own predicted_prob_black with the BISG value carried in
# pprob_black, and discard black_prob.
ncdata = ncdata.drop(columns=['predicted_prob_black', 'black_prob'])
ncdata['predicted_prob_black'] = ncdata.pop('pprob_black')
ncdata['p_black_rd'] = ncdata['predicted_prob_black'].round(2)

ncdata['unitwt'] = 1

# Flag is 1 only for audited rows whose source-code string contains neither
# code 80 nor code 91 (matched as '[80]', ' 80]', '[80 ' and likewise for 91).
_research_tags = ('[80]', ' 80]', '[80 ', '[91]', ' 91]', '[91 ')
ncdata['aud_no_research_audits'] = [
    1 if audited_flag == 1 and not any(tag in codes for tag in _research_tags) else 0
    for codes, audited_flag in zip(ncdata.audit_source_code.astype(str), ncdata.audited)
]

# Coerce the columns used downstream to numeric dtypes.
for _numeric_col in ('aud_no_research_audits', 'p_black_rd', 'uswgt', 'unitwt', 'black_ind'):
    ncdata[_numeric_col] = pd.to_numeric(ncdata[_numeric_col])


# Attach the Gibbs 'black' column to every BISG row (left merge keeps all BISG rows).
data = dataBISG.merge(dataGibbs, how = 'left', on = 'taxpayer_id')

data_baseline = dataBISG[["taxpayer_id", "aud_no_research_audits","predicted_prob_black", "case", "isEIC"]]
data_fsg = data_baseline.loc[data_baseline['case'] == "BIFSG"]
# Drop rows with any missing value. dropna() returns a new frame, so nothing is
# modified in place on a column selection of `data`.
data_gibbs = data[["taxpayer_id", "aud_no_research_audits", "black", "isEIC"]].dropna()


# Split each sample by the isEIC indicator.
eitc_baseline = data_baseline.loc[data_baseline['isEIC'] == 1]
eitc_fsg = data_fsg.loc[data_fsg['isEIC'] == 1]
eitc_gibbs = data_gibbs.loc[data_gibbs['isEIC'] == 1]

non_eitc_baseline = data_baseline.loc[data_baseline['isEIC'] == 0]
non_eitc_fsg = data_fsg.loc[data_fsg['isEIC'] == 0]
non_eitc_gibbs = data_gibbs.loc[data_gibbs['isEIC'] == 0]


# Column pairings shared by the estimator calls in this section.
_bisg_cols = {'pbvarb': 'predicted_prob_black', 'outcome': 'aud_no_research_audits'}
_gibbs_cols = {'pbvarb': 'black', 'outcome': 'aud_no_research_audits'}

# Each sub-section prints, one per line: the regEstimate result tuple, the
# chenEstimate value, the regression SE and the rescaled SE.
print("\n\n Full Pop")

## baseline
print("\n Baseline")
linDisparity = regEstimate(data_baseline, **_bisg_cols)
probDisparity = chenEstimate(data_baseline, **_bisg_cols)
linSE, probSE = getSEs(data_baseline, seReg=float(linDisparity[1]), **_bisg_cols)
print(linDisparity, probDisparity, linSE, probSE, sep="\n")

## fsg only
print("\n BIFSG only")
linDisparity_fsg = regEstimate(data_fsg, **_bisg_cols)
probDisparity_fsg = chenEstimate(data_fsg, **_bisg_cols)
linSE_fsg, probSE_fsg = getSEs(data_fsg, seReg=float(linDisparity_fsg[1]), **_bisg_cols)
print(linDisparity_fsg, probDisparity_fsg, linSE_fsg, probSE_fsg, sep="\n")

## gibbs
print("\n Gibbs")
linDisparity_gibbs = regEstimate(data_gibbs, **_gibbs_cols)
probDisparity_gibbs = chenEstimate(data_gibbs, **_gibbs_cols)
linSE_gibbs, probSE_gibbs = getSEs(data_gibbs, seReg=float(linDisparity_gibbs[1]), **_gibbs_cols)
print(linDisparity_gibbs, probDisparity_gibbs, linSE_gibbs, probSE_gibbs, sep="\n")



# Column pairings shared by the estimator calls in this section.
_bisg_cols = {'pbvarb': 'predicted_prob_black', 'outcome': 'aud_no_research_audits'}
_gibbs_cols = {'pbvarb': 'black', 'outcome': 'aud_no_research_audits'}

# Same three estimates as the full-population section, on the isEIC == 1 rows.
print("\n\n EITC Pop")

## baseline
print("\n Baseline")
linDisparity_eitc = regEstimate(eitc_baseline, **_bisg_cols)
probDisparity_eitc = chenEstimate(eitc_baseline, **_bisg_cols)
linSE_eitc, probSE_eitc = getSEs(eitc_baseline, seReg=float(linDisparity_eitc[1]), **_bisg_cols)
print(linDisparity_eitc, probDisparity_eitc, linSE_eitc, probSE_eitc, sep="\n")

## fsg only
print("\n BIFSG only")
linDisparity_fsg_eitc = regEstimate(eitc_fsg, **_bisg_cols)
probDisparity_fsg_eitc = chenEstimate(eitc_fsg, **_bisg_cols)
linSE_fsg_eitc, probSE_fsg_eitc = getSEs(eitc_fsg, seReg=float(linDisparity_fsg_eitc[1]), **_bisg_cols)
print(linDisparity_fsg_eitc, probDisparity_fsg_eitc, linSE_fsg_eitc, probSE_fsg_eitc, sep="\n")

## gibbs
print("\n Gibbs")
linDisparity_gibbs_eitc = regEstimate(eitc_gibbs, **_gibbs_cols)
probDisparity_gibbs_eitc = chenEstimate(eitc_gibbs, **_gibbs_cols)
linSE_gibbs_eitc, probSE_gibbs_eitc = getSEs(eitc_gibbs, seReg=float(linDisparity_gibbs_eitc[1]), **_gibbs_cols)
print(linDisparity_gibbs_eitc, probDisparity_gibbs_eitc, linSE_gibbs_eitc, probSE_gibbs_eitc, sep="\n")

# Column pairings shared by the estimator calls in this section.
_bisg_cols = {'pbvarb': 'predicted_prob_black', 'outcome': 'aud_no_research_audits'}
_gibbs_cols = {'pbvarb': 'black', 'outcome': 'aud_no_research_audits'}

# Same three estimates as the full-population section, on the isEIC == 0 rows.
print("\n\n Non EITC Pop")

## baseline
print("\n Baseline")
linDisparity_non_eitc = regEstimate(non_eitc_baseline, **_bisg_cols)
probDisparity_non_eitc = chenEstimate(non_eitc_baseline, **_bisg_cols)
linSE_non_eitc, probSE_non_eitc = getSEs(non_eitc_baseline, seReg=float(linDisparity_non_eitc[1]), **_bisg_cols)
print(linDisparity_non_eitc, probDisparity_non_eitc, linSE_non_eitc, probSE_non_eitc, sep="\n")

## fsg only
print("\n BIFSG only")
linDisparity_fsg_non_eitc = regEstimate(non_eitc_fsg, **_bisg_cols)
probDisparity_fsg_non_eitc = chenEstimate(non_eitc_fsg, **_bisg_cols)
linSE_fsg_non_eitc, probSE_fsg_non_eitc = getSEs(non_eitc_fsg, seReg=float(linDisparity_fsg_non_eitc[1]), **_bisg_cols)
print(linDisparity_fsg_non_eitc, probDisparity_fsg_non_eitc, linSE_fsg_non_eitc, probSE_fsg_non_eitc, sep="\n")

## gibbs
print("\n Gibbs")
linDisparity_gibbs_non_eitc = regEstimate(non_eitc_gibbs, **_gibbs_cols)
probDisparity_gibbs_non_eitc = chenEstimate(non_eitc_gibbs, **_gibbs_cols)
linSE_gibbs_non_eitc, probSE_gibbs_non_eitc = getSEs(non_eitc_gibbs, seReg=float(linDisparity_gibbs_non_eitc[1]), **_gibbs_cols)
print(linDisparity_gibbs_non_eitc, probDisparity_gibbs_non_eitc, linSE_gibbs_non_eitc, probSE_gibbs_non_eitc, sep="\n")


#### Recalibration

## Linear Estimator
def regEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
        """Fit ``outcome ~ pbvarb`` by OLS (HC1 errors) and return ``(coef, se, model)``.

        NOTE: this rebinds the module-level name ``regEstimate`` that is defined
        earlier in the file. Unlike that earlier version, the third element of
        the returned tuple is the fitted results object (callers below read
        ``.resid`` from it), and no audit-rate levels are returned.

        The number of observations used in the fit is printed.
        """
        formula = "{} ~ {}".format(outcome, pbvarb)
        fitted = sm.OLS.from_formula(formula, data = dataset).fit(cov_type = 'HC1')
        slope = fitted.params[pbvarb]
        slope_se = fitted.bse[pbvarb]
        print("number of obs :" + str(fitted.nobs))
        return slope, slope_se, fitted


## recalibrate prob_black
# Regress black_ind on predicted_prob_black in the NC data, once with unit
# weights and once with the uswgt weights; both use HC1 standard errors.
B_b_model = sm.WLS.from_formula("black_ind ~ predicted_prob_black", weights = ncdata.unitwt, data = ncdata).fit(cov_type = "HC1")
B_b_model_weight = sm.WLS.from_formula("black_ind ~ predicted_prob_black", weights = ncdata.uswgt, data = ncdata).fit(cov_type = "HC1")


### Full population
dataBISG['predicted_prob_black']=pd.to_numeric(dataBISG['predicted_prob_black'])

# Recalibrated probabilities: the NC models' predictions on the full BISG data.
dataBISG['bhat'] = B_b_model.predict(dataBISG['predicted_prob_black'])
dataBISG['bhat_weight'] = B_b_model_weight.predict(dataBISG['predicted_prob_black'])
# NOTE(review): the residuals below come from models fitted on ncdata, so they
# are presumably indexed by ncdata's rows; assigning them here would align on
# index label rather than taxpayer_id, and dataBISG rows with no matching label
# would get NaN. Confirm this alignment is intended.
dataBISG['epsilonhat'] = B_b_model.resid
dataBISG['epsilonhat_weight'] = B_b_model_weight.resid
### Linear Estimator
lin_disp = regEstimate(dataBISG, pbvarb = 'bhat', outcome = 'aud_no_research_audits')
lin_disp_weight = regEstimate(dataBISG, pbvarb = 'bhat_weight', outcome = 'aud_no_research_audits')   
# Residuals of the outcome regressions (third element is the fitted model).
dataBISG['etahat'] = lin_disp[2].resid
dataBISG['etahat_weight'] = lin_disp_weight[2].resid
## overall covariance
# Returns (unweighted, weighted); also adds helper columns to dataBISG in place.
overall_cov = get_cov_term(dataBISG)
### Probabilistic Estimator
prob_disp = chenEstimate(dataBISG, pbvarb = 'bhat', outcome = 'aud_no_research_audits', overall_cov = overall_cov[0])
prob_disp_weight = chenEstimate(dataBISG, pbvarb = 'bhat_weight', outcome = 'aud_no_research_audits', overall_cov = overall_cov[1])
### Standard errors
lin_se, prob_se = getSEs(dataBISG, pbvarb='bhat', outcome='aud_no_research_audits', seReg = float(lin_disp[1]))
lin_se_weight, prob_se_weight = getSEs(dataBISG, pbvarb='bhat_weight', outcome='aud_no_research_audits', seReg = float(lin_disp_weight[1]))



### EITC population

# Take explicit copies: columns are assigned onto both frames below, and
# without .copy() those writes would target filtered selections of dataBISG
# (pandas' chained-assignment warning).
eitc_pop = dataBISG.loc[dataBISG['isEIC'] == 1].copy()
non_eitc_pop = dataBISG.loc[dataBISG['isEIC'] == 0].copy()


eitc_pop['bhat'] = B_b_model.predict(eitc_pop['predicted_prob_black'])
eitc_pop['bhat_weight'] = B_b_model_weight.predict(eitc_pop['predicted_prob_black'])
# NOTE(review): these residuals come from models fitted on ncdata, so they are
# presumably indexed by ncdata's rows; the assignment would align on index
# label, not taxpayer_id. Confirm this alignment is intended.
eitc_pop['epsilonhat'] = B_b_model.resid
eitc_pop['epsilonhat_weight'] = B_b_model_weight.resid
### Linear Estimator
lin_disp_eic = regEstimate(eitc_pop, pbvarb = 'bhat', outcome = 'aud_no_research_audits')
lin_disp_weight_eic = regEstimate(eitc_pop, pbvarb = 'bhat_weight', outcome = 'aud_no_research_audits')
# Residuals of the outcome regressions fitted on this subset.
eitc_pop['etahat'] = lin_disp_eic[2].resid
eitc_pop['etahat_weight'] = lin_disp_weight_eic[2].resid
## overall covariance
overall_cov_eic = get_cov_term(eitc_pop)
### Probabilistic Estimator
prob_disp_eic = chenEstimate(eitc_pop, pbvarb = 'bhat', outcome = 'aud_no_research_audits', overall_cov = overall_cov_eic[0])
prob_disp_weight_eic = chenEstimate(eitc_pop, pbvarb = 'bhat_weight', outcome = 'aud_no_research_audits', overall_cov = overall_cov_eic[1])
### Standard errors
lin_se_eic, prob_se_eic = getSEs(eitc_pop, pbvarb='bhat', outcome='aud_no_research_audits', seReg = float(lin_disp_eic[1]))
lin_se_weight_eic, prob_se_weight_eic = getSEs(eitc_pop, pbvarb='bhat_weight', outcome='aud_no_research_audits', seReg = float(lin_disp_weight_eic[1]))



### Non-EITC population


non_eitc_pop['bhat'] = B_b_model.predict(non_eitc_pop['predicted_prob_black'])
non_eitc_pop['bhat_weight'] = B_b_model_weight.predict(non_eitc_pop['predicted_prob_black'])
# NOTE(review): same index-alignment caveat as for the EITC subset above.
non_eitc_pop['epsilonhat'] = B_b_model.resid
non_eitc_pop['epsilonhat_weight'] = B_b_model_weight.resid
### Linear Estimator
lin_disp_non_eic = regEstimate(non_eitc_pop, pbvarb = 'bhat', outcome = 'aud_no_research_audits')
lin_disp_weight_non_eic = regEstimate(non_eitc_pop, pbvarb = 'bhat_weight', outcome = 'aud_no_research_audits')
non_eitc_pop['etahat'] = lin_disp_non_eic[2].resid
non_eitc_pop['etahat_weight'] = lin_disp_weight_non_eic[2].resid
## overall covariance
overall_cov_non_eic = get_cov_term(non_eitc_pop)
### Probabilistic Estimator
prob_disp_non_eic = chenEstimate(non_eitc_pop, pbvarb = 'bhat', outcome = 'aud_no_research_audits', overall_cov = overall_cov_non_eic[0])
prob_disp_weight_non_eic = chenEstimate(non_eitc_pop, pbvarb = 'bhat_weight', outcome = 'aud_no_research_audits', overall_cov = overall_cov_non_eic[1])
### Standard errors
lin_se_non_eic, prob_se_non_eic = getSEs(non_eitc_pop, pbvarb='bhat', outcome='aud_no_research_audits', seReg = float(lin_disp_non_eic[1]))
lin_se_weight_non_eic, prob_se_weight_non_eic = getSEs(non_eitc_pop, pbvarb='bhat_weight', outcome='aud_no_research_audits', seReg = float(lin_disp_weight_non_eic[1]))



### Print output
# For each population and weighting, print the linear estimate and its SE,
# then the probabilistic estimate and its SE, one value per line.

print("\nRecalibration Analysis\n")

print("\nFull Population unweighted\n")
print("Linear", lin_disp[0], lin_se, sep="\n")
print("Probabilistic", prob_disp, prob_se, sep="\n")

print("\nFull Population weighted\n")
print("Linear", lin_disp_weight[0], lin_se_weight, sep="\n")
print("Probabilistic", prob_disp_weight, prob_se_weight, sep="\n")

print("\nEITC Population unweighted\n")
print("Linear", lin_disp_eic[0], lin_se_eic, sep="\n")
print("Probabilistic", prob_disp_eic, prob_se_eic, sep="\n")

print("\nEITC Population weighted\n")
print("Linear", lin_disp_weight_eic[0], lin_se_weight_eic, sep="\n")
print("Probabilistic", prob_disp_weight_eic, prob_se_weight_eic, sep="\n")

print("\nNon-EITC Population unweighted\n")
print("Linear", lin_disp_non_eic[0], lin_se_non_eic, sep="\n")
# Label spelling corrected (was "Probabilisitc") to match the other sections.
print("Probabilistic", prob_disp_non_eic, prob_se_non_eic, sep="\n")

print("\nNon-EITC Population weighted\n")
print("Linear", lin_disp_weight_non_eic[0], lin_se_weight_non_eic, sep="\n")
print("Probabilistic", prob_disp_weight_non_eic, prob_se_weight_non_eic, sep="\n")

### Geography-Only Imputations
# Same estimators as above, using race_by_cbg_black as the probability column.
_geo_cols = {'pbvarb': 'race_by_cbg_black', 'outcome': 'aud_no_research_audits'}

# full population
linDisparity = regEstimate(geo_only, **_geo_cols)
probDisparity = chenEstimate(geo_only, **_geo_cols)
linSE, probSE = getSEs(geo_only, seReg=float(linDisparity[1]), **_geo_cols)

full_pop_overall_lin, full_pop_overall_prob = linDisparity[0], probDisparity

# eitc population
eitc_geo_only = geo_only[geo_only['isEIC'] == 1]

linDisparity_eic = regEstimate(eitc_geo_only, **_geo_cols)
probDisparity_eic = chenEstimate(eitc_geo_only, **_geo_cols)
linSE_eic, probSE_eic = getSEs(eitc_geo_only, seReg=float(linDisparity_eic[1]), **_geo_cols)

eitc_overall_lin, eitc_overall_prob = linDisparity_eic[0], probDisparity_eic

# non-eitc population
non_eitc_geo_only = geo_only[geo_only['isEIC'] == 0]

linDisparity_non_eic = regEstimate(non_eitc_geo_only, **_geo_cols)
probDisparity_non_eic = chenEstimate(non_eitc_geo_only, **_geo_cols)
linSE_non_eic, probSE_non_eic = getSEs(non_eitc_geo_only, seReg=float(linDisparity_non_eic[1]), **_geo_cols)

non_eitc_overall_lin, non_eitc_overall_prob = linDisparity_non_eic[0], probDisparity_non_eic


# Report the geography-only results: one stanza per population, each listing
# the linear and probabilistic estimates, their SEs and the row count.
print("\nGeography-Only Imputations")
for _title, _lin, _lin_se, _prob, _prob_se, _frame in (
        ("Full population \n", full_pop_overall_lin, linSE, full_pop_overall_prob, probSE, geo_only),
        ("EITC population \n", eitc_overall_lin, linSE_eic, eitc_overall_prob, probSE_eic, eitc_geo_only),
        ("Non EITC population \n", non_eitc_overall_lin, linSE_non_eic, non_eitc_overall_prob, probSE_non_eic, non_eitc_geo_only),
):
        print(_title)
        print("Overall Linear Disparity: ", _lin, sep="")
        print("Linear Standard Error: ", _lin_se, sep="")
        print("Overall Probabilistic Disparity: ", _prob, sep="")
        print("Probabilistic Standard Error: ", _prob_se, sep="")
        print("N: ", len(_frame), "\n", sep="")

# stdout was redirected to a results file earlier in the script. Close that
# file before restoring the real stdout so buffered output is flushed and the
# handle is released; never close the interpreter's own stdout.
if sys.stdout is not sys.__stdout__:
        sys.stdout.close()
sys.stdout = sys.__stdout__
